package com.mamingchao.basic.tree.binaryTree;

import java.util.Objects;

/**
 * 自实现的SB树
 * 叔叔的子节点数，不小于侄子的子节点数
 * 即任何一个节点的子节点数，都不小于兄弟节点的子节点的所有子节点数
 */
public class SizeBalancedTree {

    // 在SizeBalancedTree的定义里，该TreeNode里的value为
    private TreeNode root;

    public SizeBalancedTree(TreeNode root){
        this.root = root;
    }

    public static void main(String[] args) {
        TreeNode root10 = new TreeNode(10, "0");
        TreeNode node4 = new TreeNode(4, "0");
        TreeNode node6 = new TreeNode(6, "0");
        TreeNode node7 = new TreeNode(7, "0");
        TreeNode node8 = new TreeNode(8, "0");
        TreeNode node9 = new TreeNode(9, "0");
        TreeNode node11 = new TreeNode(11, "0");

        SizeBalancedTree tree = new SizeBalancedTree(root10);
        tree.insert(node4);
        tree.insert(node6);
        // tree.insert(node7);
        // tree.insert(node8);
        // tree.insert(node9);
        // tree.insert(node11);

        BinaryTreeIterator.preOrderRecurse(root10);
        
    }


    /**
     * 向SB树中插入一个新节点
     * 插入后需要从插入位置开始向上检查每个节点是否满足SB树的定义
     * 如果不满足则需要进行旋转操作来重新平衡
     * @param node 要插入的新节点
     * @return 插入后的根节点
     */

    public boolean insert(TreeNode node) {
        TreeNode startNode = root;
        // 不支持重复的key的插入
        if (Objects.nonNull(this.find(root, node.getKey()) )) {
            System.out.println("key is exist");
            return false;
        };

        while (startNode != null){
            
            // 插入节点，从上到下的每个节点，size都增加1
            startNode.setValue(String.valueOf(Integer.valueOf(startNode.getValue()) + 1));

            if (node.getKey() < startNode.getKey()) {
                if (Objects.isNull(startNode.leftChild)) {
                    startNode.leftChild = node;
                    break;
                }
                startNode = startNode.leftChild;
            } else {
                if (Objects.isNull(startNode.rightChild)) {
                    startNode.rightChild = node;
                    break;
                }
                startNode = startNode.rightChild;
            }

        }
        
        rebalance(root);
        return true;
    }

    /**
     * 删除key对应的节点
     * 1. 如果key不存在，直接返回null
     * 2. 如果key存在：
     *    - 如果是叶子节点，直接删除
     *    - 如果只有一个子节点，用子节点替换当前节点
     *    - 如果有两个子节点，找到后继节点替换当前节点，然后删除后继节点
     * 3. 删除节点后需要从删除点开始向上调整平衡性
     * @param key 要删除的节点的key值
     * @return 被删除的节点，如果key不存在返回null
     */
    public boolean delete(TreeNode childNode, int key) {
      
        if (Objects.isNull(childNode)) {
            return false;
        }

        // 如果key不存在，则报错
        if (Objects.nonNull(this.find(root, key) )) {
            System.out.println("key is not exist");
            return false;
        };
                

        while (Objects.nonNull(childNode)) {

            // 找到了节点，根据不同情况执行删除动作删除
            if (childNode.getKey() == key) {
    
                // 插入节点，从上到下的每个节点，size都增加1
                childNode.setValue(String.valueOf(Integer.valueOf(childNode.getValue()) + 1));
                // 如果命中了节点，并且该节点是叶子节点，直接删掉
                if (childNode.leftChild == null && childNode.rightChild== null) {
                    childNode = null;
                    rebalance(root);
                } else if(childNode.rightChild != null) {
                    // 当前节点的右孩子都在，使用当前节点右孩子的最左孩子 替换当前节点
                    replaceDeletedNode(childNode, false);
                
                } else if (childNode.leftChild != null){
    
                    // 当前节点的左孩子都在，使用当前节点左孩子的最右孩子 替换当前节点；
                    replaceDeletedNode(childNode, true);
                }
            } else {
                // 搜索二叉树，没找到key，继续寻找
                if (key < childNode.getKey()) {
                    childNode = childNode.leftChild;
                } else {
                    childNode = childNode.rightChild;;
                }
            }
        }

        return true;

    }


    private void replaceDeletedNode(TreeNode node, boolean LR) {
        // 当前节点的左右孩子都在，使用当前节点左孩子的最右孩子 替换当前节点；
        TreeNode leftChild = node.leftChild;
        TreeNode rightChild  = node.rightChild;
        TreeNode predecessor = null;
        if (LR) {
            predecessor = this.searchLR(node);
            node = predecessor.rightChild;
            node.leftChild = leftChild;
            node.rightChild = rightChild;
            predecessor.rightChild = null;
        } else {
            predecessor = this.searchRL(node);
            node = predecessor.leftChild;
            node.leftChild = leftChild;
            node.rightChild = rightChild;
            predecessor.leftChild = null;
        }
        rebalance(node);
    }

    /**
     * 查询指定节点左孩子的最右子孩子节点的父节点
     * @param node
     * @return
     */
    private TreeNode searchLR(TreeNode node) {
        TreeNode predecessor = node.leftChild;
        if (predecessor == null) {
            return node;
        } else {
            while (predecessor.rightChild != null && predecessor.rightChild.rightChild != null) {
                predecessor = predecessor.rightChild;
            }
            return predecessor;
        }
    }

     /**
     * 查询指定节点左孩子的最右子孩子节点的父节点
     * @param node
     * @return
     */
    private TreeNode searchRL(TreeNode node) {
        TreeNode predecessor = node.rightChild;
        if (predecessor == null) {
            return node;
        } else {

            while  (predecessor.leftChild != null && predecessor.leftChild.leftChild != null) {
                predecessor = predecessor.leftChild;
            }
            return predecessor;
        }
    }

    /**
     * 获取指定节点的子树大小
     * @param node 要查询的节点
     * @return 该节点子树的大小(包含自身)
     */
    private int getSize(TreeNode node) {
        if (Objects.isNull(node)) {
            return 0;
        }

        return getSize(node.leftChild) + getSize(node.rightChild) + 1;

        // return Objects.isNull(node) ？ 0 : node.getSize();
    }

    /**
     * 获取指定节点的子树大小
     * @param node 要查询的节点
     * @return 该节点子树的大小(包含自身)
     */
    private int updateSize(TreeNode node) {
       
        int size = getSize(node.leftChild) + getSize(node.rightChild) + 1;
        node.setValue(String.valueOf(size));
        return size;
    }
    

    /**
     * 在以node为根的子树中查找key对应的节点
     * @param node 子树的根节点
     * @param key 要查找的key值
     * @return 找到的节点，如果不存在返回null
     */
    private TreeNode find(TreeNode node, int key) {
        if (Objects.isNull(node)) {
            return null;
        }
        if (node.getKey() == key) {
            return node;
        }
        if (key < node.getKey()) {
            return find(node.leftChild, key);
        } else {
            return find(node.rightChild, key);
        }
    }

    /**
     * 从startNode开始，向上逐个检查每个节点，检查是否满足SizeBalancedTree的定义
     * 如果 不满足，跟进 是LL RR LR RL  四种情况，分别做处理
     * @param startNode
     */
    private void rebalance(TreeNode startNode) {
        if (startNode == null) {
            return;
        }

        // 这里是 维持左右平衡的一个判断标准，这个可以根据要求和需求自定义
        // 比如 左边比右边的大小差值大于1，则需要调整
        // 比如 左边比右边的二倍还大，等判断条件
        if (getSize(startNode.leftChild) < getSize(startNode.rightChild) + 1) {

            if (startNode.rightChild != null && getSize(startNode.rightChild.leftChild) > getSize(startNode.rightChild.rightChild)) {
                
                startNode.rightChild = rightRotate(startNode.rightChild);
            }
            
            leftRotate(startNode);
        }

        if (getSize(startNode.leftChild) > getSize(startNode.rightChild) + 1) {

            if (startNode.leftChild != null && getSize(startNode.leftChild.leftChild) < getSize(startNode.leftChild.rightChild)) {
                
                startNode.leftChild = leftRotate(startNode.leftChild);
            }
            
            rightRotate(startNode);
        }
    }


    /**
     * 左旋操作；基于keyNode进行左旋操作
     * 30
      /  \
    20    40
   /  \     \
  10   25    50
     * @param keyNode
     */
    private TreeNode leftRotate(TreeNode keyNode) {
        // keyNode 必须有右孩子
        if (Objects.isNull(keyNode.rightChild)) {
            return null;
        }

        TreeNode rightChild = keyNode.rightChild;
        // 比如30左旋，30的右孩子设置为40的左孩子
        keyNode.rightChild = rightChild.leftChild;
        // 40的左孩子设置为30
        rightChild.leftChild = keyNode;

        return rightChild;
    }
    
    /**
     * 右旋操作；基于keyNode进行右旋操作
     * 30
      /  \
    20    40
   /  \     \
  10   25    50
     * @param keyNode
     */
    private TreeNode rightRotate(TreeNode keyNode) {
        // keyNode 必须有左孩子
        if (Objects.isNull(keyNode.leftChild)) {
            return null;
        }

        // 比如30右旋，30的左孩子设置为20的右孩子
        // keyNode.leftChild 为 20
        // keyNode 为 30
        TreeNode leftChild = keyNode.leftChild;
        keyNode.leftChild  = leftChild.rightChild;
        // 20的右孩子设置为30
        leftChild.rightChild = keyNode;
        return leftChild;
    } 

    /**
     * 在以node为根的子树中查找key的前驱节点
     * @param node 子树的根节点
     * @param key 要查找前驱的key值
     * @return key的前驱节点，如果不存在返回null
     */
    private TreeNode findPredecessor(TreeNode node, int key) {
        if (Objects.isNull(node)) {
            return null;
        }
        if (node.getKey() >= key) {
            return findPredecessor(node.leftChild, key);
        }
        TreeNode right = findPredecessor(node.rightChild, key);
        return right != null ? right : node;
    }

    /**
     * 在以node为根的子树中查找key的后继节点
     * @param node 子树的根节点
     * @param key 要查找后继的key值
     * @return key的后继节点，如果不存在返回null
     */
    private TreeNode findSuccessor(TreeNode node, int key) {
        if (Objects.isNull(node)) {
            return null;
        }
        if (node.getKey() <= key) {
            return findSuccessor(node.rightChild, key);
        }
        TreeNode left = findSuccessor(node.leftChild, key);
        return left != null ? left : node;
    }

}
